from model.satcl_modules import Linear, Conv2d, conv3x3
from model.satcl_module_utils import compute_conv_output_size
from torch.nn.functional import relu, avg_pool2d

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from copy import deepcopy


# Define MLP model
class MLPNet(nn.Module):
    """Three-layer, bias-free MLP built from the project ``Linear`` wrapper.

    ``lin1`` expects 784 input features, ``lin2`` maps hidden -> hidden, and
    ``fc1`` maps hidden -> ``args.n_outputs``. Each wrapped layer is called
    with the extra ``(t, p, epoch)`` arguments. The input fed to each layer
    on the latest forward pass is cached in ``self.act`` under the keys
    'Lin1', 'Lin2' and 'fc1'.
    """

    def __init__(self, args):
        super(MLPNet, self).__init__()
        self.args = args
        self.act = OrderedDict()
        hidden = args.n_hidden
        self.lin1 = Linear(784, hidden, bias=False)
        self.lin2 = Linear(hidden, hidden, bias=False)
        self.fc1 = Linear(hidden, args.n_outputs, bias=False)

    def forward(self, x, t, p, epoch):
        """Return the logits of ``fc1``.

        Args:
            x: input batch passed straight to ``lin1``.
            t: task index, forwarded to every wrapped layer.
            p: ``None`` (every layer receives ``None``), or an indexable whose
                entries 0, 1, 2 go to lin1, lin2 and fc1 respectively.
            epoch: forwarded unchanged to every wrapped layer.
        """
        per_layer = (None, None, None) if p is None else p

        self.act['Lin1'] = x
        hidden = F.relu(self.lin1(x, t, per_layer[0], epoch))

        self.act['Lin2'] = hidden
        hidden = F.relu(self.lin2(hidden, t, per_layer[1], epoch))

        self.act['fc1'] = hidden
        return self.fc1(hidden, t, per_layer[2], epoch)
    

class AlexNet(nn.Module):
    """AlexNet-style network with one bias-free linear head per task.

    conv1-conv3 use the project ``Conv2d`` wrapper and fc1/fc2 the project
    ``Linear`` wrapper; these are called with the extra ``(t, p, epoch)``
    arguments. The per-task heads in ``fc3`` are plain ``nn.Linear`` layers,
    one per entry of ``args.taskcla`` (the second element of each entry is
    the head's output size).

    Bookkeeping attributes filled in ``__init__``:
        map: 32 (the input height/width the sizes are computed from), the
            spatial size of the conv2 and conv3 inputs, then the input
            feature counts of fc1 and fc2.
        ksize: kernel sizes of conv1, conv2, conv3.
        in_channel: input channel counts of conv1, conv2, conv3.
        act: the most recent input of each wrapped layer, keyed by 'conv1',
            'conv2', 'conv3', 'fc1', 'fc2' (overwritten on every forward).
    """

    def __init__(self, args):
        # nn.Module.__init__ must run before attributes are assigned.
        super(AlexNet, self).__init__()
        self.args = args
        self.act = OrderedDict()
        self.map = []
        self.ksize = []
        self.in_channel = []

        # NOTE: module construction order is kept as-is because it fixes the
        # order of model.parameters() and the RNG draws used for init.

        # Block 1: 3 -> 64 channels, 4x4 kernel; `s // 2` accounts for the
        # 2x2 max-pool applied in forward.
        self.map.append(32)
        self.conv1 = Conv2d(3, 64, 4, bias=False)
        s = compute_conv_output_size(32, 4)
        self.ln1 = nn.BatchNorm2d(64, track_running_stats=False)
        s = s // 2
        self.ksize.append(4)
        self.in_channel.append(3)
        self.map.append(s)

        # Block 2: 64 -> 128 channels, 3x3 kernel, then 2x2 max-pool.
        self.conv2 = Conv2d(64, 128, 3, bias=False)
        s = compute_conv_output_size(s, 3)
        self.ln2 = nn.BatchNorm2d(128, track_running_stats=False)
        s = s // 2
        self.ksize.append(3)
        self.in_channel.append(64)
        self.map.append(s)

        # Block 3: 128 -> 256 channels, 2x2 kernel, then 2x2 max-pool.
        self.conv3 = Conv2d(128, 256, 2, bias=False)
        s = compute_conv_output_size(s, 2)
        self.bn3 = nn.BatchNorm2d(256, track_running_stats=False)
        s = s // 2
        self.smid = s
        self.ksize.append(2)
        self.in_channel.append(128)
        self.map.append(256 * self.smid * self.smid)

        self.maxpool = torch.nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.drop1 = torch.nn.Dropout(0.2)
        self.drop2 = torch.nn.Dropout(0.5)

        self.fc1 = Linear(256 * self.smid * self.smid, 2048, bias=False)
        self.ln4 = nn.BatchNorm1d(2048, track_running_stats=False)
        self.fc2 = Linear(2048, 2048, bias=False)
        self.ln5 = nn.BatchNorm1d(2048, track_running_stats=False)
        self.map.extend([2048])

        self.taskcla = self.args.taskcla
        self.fc3 = torch.nn.ModuleList()
        for _, n_classes in self.taskcla:
            self.fc3.append(nn.Linear(2048, n_classes, bias=False))

        self.backup = {}

    def forward(self, x, t, p, epoch):
        """Return a list with one logit tensor per task head.

        Args:
            x: input image batch; dimension 0 is the batch size.
            t: task index, forwarded to the wrapped Conv2d/Linear layers.
            p: ``None`` (every wrapped layer receives ``None``), or an
                indexable whose entries 0..4 go to conv1, conv2, conv3, fc1
                and fc2 respectively.
            epoch: forwarded unchanged to the wrapped layers.

        Returns:
            list of head outputs, in the order of ``self.fc3``.
        """
        # One code path for both cases; p entries are still read lazily.
        layer_p = (None,) * 5 if p is None else p
        bsz = x.size(0)  # plain int; no copy needed

        self.act['conv1'] = x
        x = self.conv1(x, t, layer_p[0], epoch)
        x = self.maxpool(self.drop1(self.relu(self.ln1(x))))

        self.act['conv2'] = x
        x = self.conv2(x, t, layer_p[1], epoch)
        x = self.maxpool(self.drop1(self.relu(self.ln2(x))))

        self.act['conv3'] = x
        x = self.conv3(x, t, layer_p[2], epoch)
        x = self.maxpool(self.drop2(self.relu(self.bn3(x))))

        x = x.view(bsz, -1)
        self.act['fc1'] = x
        x = self.fc1(x, t, layer_p[3], epoch)
        x = self.drop2(self.relu(self.ln4(x)))

        self.act['fc2'] = x
        x = self.fc2(x, t, layer_p[4], epoch)
        x = self.drop2(self.relu(self.ln5(x)))

        # Iterate the heads directly: the old `for t, i in self.taskcla`
        # rebound the `t` argument and assumed task ids equal head positions.
        return [head(x) for head in self.fc3]
    

class LeNet(nn.Module):
    """LeNet-style network with one bias-free linear head per task.

    conv1/conv2 use the project ``Conv2d`` wrapper and fc1/fc2 the project
    ``Linear`` wrapper; these are called with the extra ``(t, p, epoch)``
    arguments. The per-task heads in ``fc3`` are plain ``nn.Linear`` layers,
    one per entry of ``args.taskcla`` (the second element of each entry is
    the head's output size).

    Bookkeeping attributes filled in ``__init__``:
        map: 32 (the input height/width the sizes are computed from), the
            spatial size of the conv2 input, the input feature count of fc1,
            and 800 (the input feature count of fc2).
        ksize: kernel sizes of conv1, conv2.
        in_channel: input channel counts of conv1, conv2.
        act: the most recent input of each wrapped layer, keyed by 'conv1',
            'conv2', 'fc1', 'fc2' (overwritten on every forward).
    """

    def __init__(self, args):
        # nn.Module.__init__ must run before attributes are assigned.
        super(LeNet, self).__init__()
        self.args = args
        self.act = OrderedDict()
        self.map = []
        self.ksize = []
        self.in_channel = []

        # NOTE: module construction order is kept as-is because it fixes the
        # order of model.parameters() and the RNG draws used for init.

        # Block 1: 3 -> 20 channels, 5x5 kernel with padding 2; the second
        # size computation mirrors MaxPool2d(3, 2, padding=1) used in forward.
        self.map.append(32)
        self.conv1 = Conv2d(3, 20, 5, bias=False, padding=2)
        s = compute_conv_output_size(32, 5, 1, 2)
        s = compute_conv_output_size(s, 3, 2, 1)
        self.ksize.append(5)
        self.in_channel.append(3)
        self.map.append(s)

        # Block 2: 20 -> 50 channels, same kernel/padding/pool.
        self.conv2 = Conv2d(20, 50, 5, bias=False, padding=2)
        s = compute_conv_output_size(s, 5, 1, 2)
        s = compute_conv_output_size(s, 3, 2, 1)
        self.ksize.append(5)
        self.in_channel.append(20)
        self.smid = s
        self.map.append(50 * self.smid * self.smid)

        self.maxpool = torch.nn.MaxPool2d(3, 2, padding=1)
        self.relu = torch.nn.ReLU()
        self.drop1 = torch.nn.Dropout(0)
        self.drop2 = torch.nn.Dropout(0)
        self.lrn = torch.nn.LocalResponseNorm(4, 0.001 / 9.0, 0.75, 1)

        self.fc1 = Linear(50 * self.smid * self.smid, 800, bias=False)
        self.fc2 = Linear(800, 500, bias=False)
        self.map.extend([800])

        self.taskcla = self.args.taskcla
        self.fc3 = torch.nn.ModuleList()
        for _, n_classes in self.taskcla:
            self.fc3.append(torch.nn.Linear(500, n_classes, bias=False))

    def forward(self, x, t, p, epoch):
        """Return a list with one logit tensor per task head.

        Args:
            x: input image batch; dimension 0 is the batch size.
            t: task index, forwarded to the wrapped Conv2d/Linear layers.
            p: ``None`` (every wrapped layer receives ``None``), or an
                indexable whose entries 0..3 go to conv1, conv2, fc1 and fc2
                respectively.
            epoch: forwarded unchanged to the wrapped layers.

        Returns:
            list of head outputs, in the order of ``self.fc3``.
        """
        # One code path for both cases; p entries are still read lazily.
        layer_p = (None,) * 4 if p is None else p
        bsz = x.size(0)  # plain int; no copy needed

        self.act['conv1'] = x
        x = self.conv1(x, t, layer_p[0], epoch)
        x = self.maxpool(self.drop1(self.lrn(self.relu(x))))

        self.act['conv2'] = x
        x = self.conv2(x, t, layer_p[1], epoch)
        x = self.maxpool(self.drop1(self.lrn(self.relu(x))))

        x = x.reshape(bsz, -1)
        self.act['fc1'] = x
        x = self.fc1(x, t, layer_p[2], epoch)
        x = self.drop2(self.relu(x))

        self.act['fc2'] = x
        x = self.fc2(x, t, layer_p[3], epoch)
        x = self.drop2(self.relu(x))

        # Iterate the heads directly: the old `for t, i in self.taskcla`
        # rebound the `t` argument and assumed task ids equal head positions.
        return [head(x) for head in self.fc3]


# ---------------- ResNet-18 components ----------------
class Sequential(nn.Sequential):
    """``nn.Sequential`` whose forward threads ``(t, p, epoch)`` to each child.

    Construction mirrors ``nn.Sequential``: pass either a single
    ``OrderedDict`` mapping names to modules, or the modules positionally
    (they are then registered as '0', '1', ...).
    """

    def __init__(self, *args):
        super(Sequential, self).__init__()
        single_mapping = len(args) == 1 and isinstance(args[0], OrderedDict)
        if single_mapping:
            named_children = args[0].items()
        else:
            named_children = ((str(pos), child) for pos, child in enumerate(args))
        for child_name, child in named_children:
            self.add_module(child_name, child)

    def forward(self, input, t, p, epoch):
        """Feed ``input`` through every child as ``child(out, t, p, epoch)``.

        The same ``t``, ``p`` and ``epoch`` are given to every child; an empty
        container returns ``input`` unchanged.
        """
        result = input
        for child in self:
            result = child(result, t, p, epoch)
        return result

class Shortcut(nn.Module):
    """Skip branch of a residual block.

    When the stride is 1 and ``in_planes == expansion * planes`` the branch is
    the identity (an empty ``Sequential``). Otherwise it is a bias-free 1x1
    ``nn.Conv2d`` with the given stride, followed by one of 20
    ``nn.BatchNorm2d`` layers selected by the task index ``t`` (so ``t`` must
    be below 20 on that path).
    """

    def __init__(self, stride, in_planes, expansion, planes):
        super(Shortcut, self).__init__()
        out_planes = expansion * planes
        needs_projection = stride != 1 or in_planes != out_planes
        self.identity = not needs_projection
        self.shortcut = Sequential()

        if needs_projection:
            self.conv1 = nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            self.bn1 = nn.ModuleList(
                nn.BatchNorm2d(out_planes) for _ in range(20))

    def forward(self, x, t, p, epoch):
        """Return the shortcut output for task ``t``.

        ``p`` and ``epoch`` are only forwarded on the identity path; the
        projection path uses a plain ``nn.Conv2d`` that ignores them.
        """
        if not self.identity:
            return self.bn1[t](self.conv1(x))
        return self.shortcut(x, t, p, epoch)

class BasicBlock(nn.Module):
    """Two-convolution residual block with per-task batch normalization.

    ``conv1``/``conv2`` come from the project ``conv3x3`` helper and are called
    with the extra ``(t, p, epoch)`` arguments. ``bn1``/``bn2`` each hold 20
    ``nn.BatchNorm2d`` layers indexed by the task index ``t`` (so ``t`` must
    be below 20). The inputs of conv1 and conv2 on the latest forward pass
    are cached in ``self.act`` under 'conv_0' and 'conv_1'.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.ModuleList()
        for _ in range(20):
            self.bn1.append(nn.BatchNorm2d(planes))
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.ModuleList()
        for _ in range(20):
            self.bn2.append(nn.BatchNorm2d(planes))

        # Single assignment (a throw-away Sequential() used to be assigned
        # here first and immediately overwritten).
        self.shortcut = Shortcut(
            stride=stride, in_planes=in_planes, expansion=self.expansion, planes=planes)
        self.act = OrderedDict()
        self.count = 0

    def _record_act(self, tensor):
        """Cache ``tensor`` in ``self.act`` under 'conv_<k>', k cycling 0, 1."""
        self.count = self.count % 2
        self.act['conv_{}'.format(self.count)] = tensor
        self.count += 1

    def forward(self, x, t, p, epoch):
        """Apply the block for task ``t``.

        Args:
            x: input feature map.
            t: task index; selects the BatchNorm layers and is forwarded to
                the conv wrappers and the shortcut.
            p: ``None`` (both convs receive ``None``), or an indexable whose
                entries 0 and 1 go to conv1 and conv2. Any further entries are
                not read here.
            epoch: forwarded unchanged to the conv wrappers and the shortcut.
        """
        layer_p = (None, None) if p is None else p

        self._record_act(x)
        out = relu(self.bn1[t](self.conv1(x, t, layer_p[0], epoch)))

        self._record_act(out)
        out = self.bn2[t](self.conv2(out, t, layer_p[1], epoch))

        # The shortcut is always called with p=None.
        out += self.shortcut(x, t, None, epoch)
        return relu(out)



class ResNet(nn.Module):
    """ResNet with per-task stem batch norm and one linear head per task.

    Inputs are reshaped to (batch, 3, 84, 84) in ``forward``. The head input
    size hard-codes a factor of 9, presumably a 3x3 final feature map for
    84x84 inputs — confirm against ``conv3x3``'s padding if the input size
    changes.

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        taskcla: iterable of pairs whose second element is a head's output
            size; one stem BatchNorm and one head are created per entry.
        nf: base channel width.
    """

    def __init__(self, block, num_blocks, taskcla, nf):
        super(ResNet, self).__init__()
        self.taskcla = taskcla
        self.in_planes = nf
        # NOTE: construction order is kept as-is because it fixes the order
        # of model.parameters() and the RNG draws used for init.
        self.conv1 = conv3x3(3, nf * 1, 2)
        self.bn1 = nn.ModuleList()
        for _ in self.taskcla:
            self.bn1.append(nn.BatchNorm2d(nf * 1))
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)

        self.linear = torch.nn.ModuleList()
        for _, n_classes in self.taskcla:
            self.linear.append(
                nn.Linear(nf * 8 * block.expansion * 9, n_classes, bias=False))
        self.act = OrderedDict()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``.

        Updates ``self.in_planes`` so the next stage starts from this stage's
        output width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return Sequential(*layers)

    def forward(self, x, t, p, epoch):
        """Return a list with one logit tensor per task head.

        Args:
            x: input batch; dimension 0 is the batch size and the rest must
                hold 3 * 84 * 84 values per sample.
            t: task index; selects the stem BatchNorm and is forwarded to
                every block.
            p: ``None`` (every wrapped conv receives ``None``), or an
                indexable with at least 20 entries: p[0] for the stem conv,
                then the per-block slices listed below.
            epoch: forwarded unchanged to the stem conv and the blocks.

        Returns:
            list of head outputs, in the order of ``self.linear``.
        """
        bsz = x.size(0)
        x = x.view(bsz, 3, 84, 84)
        self.act['conv_in'] = x

        stem_p = None if p is None else p[0]
        out = relu(self.bn1[t](self.conv1(x, t, stem_p, epoch)))

        if p is None:
            for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
                out = stage(out, t, None, epoch)
        else:
            # Hard-coded for num_blocks == [2, 2, 2, 2]: the first block of
            # layer2..layer4 receives three entries, every other block two.
            block_inputs = (
                (self.layer1[0], p[1:3]),
                (self.layer1[1], p[3:5]),
                (self.layer2[0], p[5:8]),
                (self.layer2[1], p[8:10]),
                (self.layer3[0], p[10:13]),
                (self.layer3[1], p[13:15]),
                (self.layer4[0], p[15:18]),
                (self.layer4[1], p[18:20]),
            )
            for res_block, block_p in block_inputs:
                out = res_block(out, t, block_p, epoch)

        out = avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        # Iterate the heads directly: the old `for t, i in self.taskcla`
        # rebound the `t` argument and assumed task ids equal head positions.
        return [head(out) for head in self.linear]


def ResNet18(args):
    """Build the four-stage ResNet with two ``BasicBlock``s per stage.

    Reads ``args.taskcla`` (one head per entry) and ``args.nf`` (base width).
    """
    blocks_per_stage = [2] * 4
    return ResNet(BasicBlock, blocks_per_stage, args.taskcla, args.nf)